"""
@author: Viet Nguyen <nhviet1009@gmail.com>
"""

import collections
import subprocess as sp

import cv2
import gymnasium as gym
import numpy as np
import ptan
import torch.multiprocessing as mp
from gymnasium import Wrapper
from gymnasium.spaces import Box


class Monitor:
    def __init__(self, width, height, saved_path):

        # ffmpeg expects a lowercase "x" in the WxH size argument.
        self.command = ["ffmpeg", "-y", "-f", "rawvideo", "-vcodec", "rawvideo", "-s", "{}x{}".format(width, height),
                        "-pix_fmt", "rgb24", "-r", "60", "-i", "-", "-an", "-vcodec", "mpeg4", saved_path]
        try:
            self.pipe = sp.Popen(self.command, stdin=sp.PIPE, stderr=sp.PIPE)
        except FileNotFoundError:
            # ffmpeg is not installed: make recording a no-op instead of
            # leaving self.pipe undefined.
            self.pipe = None

    def record(self, image_array):
        if self.pipe is not None:
            # ndarray.tostring() was removed from NumPy; tobytes() is its replacement.
            self.pipe.stdin.write(image_array.tobytes())
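
# Usage sketch (illustrative; assumes ffmpeg is on PATH, and "out.mp4" is a
# hypothetical output path). Frames must be raw RGB arrays matching the
# declared width and height:
#
#   monitor = Monitor(256, 240, "out.mp4")
#   monitor.record(np.zeros((240, 256, 3), dtype=np.uint8))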


def process_frame(frame):
    if frame is not None:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (84, 84))[None, :, :] / 255.
        return frame
    else:
        return np.zeros((1, 84, 84))
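
# For example, a raw 240x256 RGB frame becomes a (1, 84, 84) grayscale array
# scaled to [0, 1]:
#
#   out = process_frame(np.zeros((240, 256, 3), dtype=np.uint8))
#   assert out.shape == (1, 84, 84)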


class CustomReward(Wrapper):
    def __init__(self, env=None, world=None, stage=None, monitor=None):
        super(CustomReward, self).__init__(env)
        # process_frame() scales pixels to [0, 1], so declare the space accordingly.
        self.observation_space = Box(low=0, high=1.0, shape=(1, 84, 84), dtype=np.float32)
        self.curr_score = 0
        self.current_x = 40
        self.world = world
        self.stage = stage
        self.monitor = monitor

    def step(self, action):
        state, reward, done, trunc, info = self.env.step(action)
        if self.monitor:
            self.monitor.record(state)
        state = process_frame(state)
        reward += (info["score"] - self.curr_score) / 40.
        self.curr_score = info["score"]
        if done or trunc:
            if info["flag_get"]:
                reward += 50
            else:
                reward -= 50
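        # World 7-4 is a maze level: end the episode with a penalty when Mario
        # takes a wrong path (coordinate ranges traced from the level layout)
        # or warps backwards.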
        if self.world == 7 and self.stage == 4:
            if (506 <= info["x_pos"] <= 832 and info["y_pos"] > 127) or (
                    832 < info["x_pos"] <= 1064 and info["y_pos"] < 80) or (
                    1113 < info["x_pos"] <= 1464 and info["y_pos"] < 191) or (
                    1579 < info["x_pos"] <= 1943 and info["y_pos"] < 191) or (
                    1946 < info["x_pos"] <= 1964 and info["y_pos"] >= 191) or (
                    1984 < info["x_pos"] <= 2060 and (info["y_pos"] >= 191 or info["y_pos"] < 127)) or (
                    2114 < info["x_pos"] < 2440 and info["y_pos"] < 191) or info["x_pos"] < self.current_x - 500:
                reward -= 50
                done = True
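        # World 4-4 is also a maze level; apply the same wrong-path check.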
        if self.world == 4 and self.stage == 4:
            if (info["x_pos"] <= 1500 and info["y_pos"] < 127) or (
                    1588 <= info["x_pos"] < 2380 and info["y_pos"] >= 127):
                reward = -50
                done = True

        self.current_x = info["x_pos"]
        return state, reward / 10., done, trunc, info

    def reset(self, **kwargs):
        self.curr_score = 0
        self.current_x = 40
        obs, info = self.env.reset(**kwargs)
        return process_frame(obs), info
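
# Usage sketch (assumes a Super Mario Bros environment, e.g. from
# gym-super-mario-bros, whose info dict exposes "score", "flag_get",
# "x_pos" and "y_pos"):
#
#   env = CustomReward(base_env, world=1, stage=1)
#   state, reward, done, trunc, info = env.step(action)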


class CustomSkipFrame(Wrapper):
    def __init__(self, env, skip=4):
        super(CustomSkipFrame, self).__init__(env)
        self.observation_space = Box(low=0, high=1.0, shape=(skip, 84, 84), dtype=np.float32)
        self.skip = skip
        self.states = np.zeros((skip, 84, 84), dtype=np.float32)

    def step(self, action):
        total_reward = 0
        last_states = []
        for i in range(self.skip):
            state, reward, done, trunc, info = self.env.step(action)
            total_reward += reward
            # Only max-pool over the later frames of the skip window.
            if i >= self.skip // 2:
                last_states.append(state)
            if done or trunc:
                self.reset()
                return self.states[None, :, :, :].astype(np.float32), total_reward, done, trunc, info
        max_state = np.max(np.concatenate(last_states, 0), 0)
        self.states[:-1] = self.states[1:]
        self.states[-1] = max_state
        return self.states[None, :, :, :].astype(np.float32), total_reward, done, trunc, info

    def reset(self, **kwargs):
        state, info = self.env.reset(**kwargs)
        self.states = np.concatenate([state for _ in range(self.skip)], 0).astype(np.float32)
        return self.states[None, :, :, :].astype(np.float32), info
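
# Usage sketch: wrapping CustomReward with CustomSkipFrame yields observations
# of shape (1, skip, 84, 84), with the newest max-pooled frame in the last slot:
#
#   env = CustomSkipFrame(CustomReward(base_env, world=1, stage=1))
#   obs, info = env.reset()  # obs.shape == (1, 4, 84, 84)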


class RewardPenaltyWrapper(gym.Wrapper):
    def __init__(self, env, frame_penalty=-0.1, life_loss_penalty=-10):
        super(RewardPenaltyWrapper, self).__init__(env)
        self.frame_penalty = frame_penalty
        self.life_loss_penalty = life_loss_penalty
        self.previous_lives = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.previous_lives = info.get('lives', 0)  # initial life count
        return obs, info

    def step(self, action):
        obs, reward, done, truncated, info = self.env.step(action)

        reward /= 50  # scale the raw reward down
        reward += self.frame_penalty  # small per-frame living penalty

        # Penalize losing a life; symmetrically reward gaining one
        # (life_loss_penalty is negative, so subtracting it adds reward).
        current_lives = info.get('lives', self.previous_lives)
        if current_lives < self.previous_lives:
            reward += self.life_loss_penalty
            self.previous_lives = current_lives
        elif current_lives > self.previous_lives:
            reward -= self.life_loss_penalty
            self.previous_lives = current_lives

        return obs, reward, done, truncated, info
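
# For example, with the defaults (frame_penalty=-0.1, life_loss_penalty=-10),
# a raw reward of 1.0 on a step that also loses a life becomes
# 1.0 / 50 + (-0.1) + (-10) = -10.08.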
    

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        super(FrameStack, self).__init__(env)
        self.k = k
        self.frames = collections.deque(maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(shp[0] * k, *shp[1:]), dtype=np.float32
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.k):
            self.frames.append(obs)
        return self._get_obs(), info

    def step(self, action):
        obs, reward, done, truncated, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_obs(), reward, done, truncated, info

    def _get_obs(self):
        # Stack with NumPy instead of returning LazyFrames: simpler, at the
        # cost of copying each observation.
        return np.concatenate(list(self.frames), axis=0)
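
# Usage sketch: with k=4 and (1, 84, 84) observations from ImageToPyTorch,
# the stacked observation has shape (4, 84, 84):
#
#   env = FrameStack(env, k=4)
#   obs, info = env.reset()  # obs.shape == (4, 84, 84)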
    

def create_train_env(env_id, output_path=None, render_mode=None):
    # output_path is accepted for API compatibility but not used here.
    env = gym.make(env_id, frameskip=4, repeat_action_probability=0.0, render_mode=render_mode)

    # Randomize the initial state with up to 30 no-op actions
    env = ptan.common.wrappers.NoopResetEnv(env, noop_max=30)
    # Frame-skip wrapper
    # env = ptan.common.wrappers.MaxAndSkipEnv(env, skip=4)

    # if 'FIRE' in env.unwrapped.get_action_meanings():
    #     env = ptan.common.wrappers.FireResetEnv(env)
    env = ptan.common.wrappers.ProcessFrame84(env)
    env = ptan.common.wrappers.ImageToPyTorch(env)
    env = FrameStack(env, 4)
    env = RewardPenaltyWrapper(env)
    return env
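
# Usage sketch (hypothetical env id; assumes an ALE environment that accepts
# the frameskip / repeat_action_probability kwargs used above):
#
#   env = create_train_env("ALE/Breakout-v5")
#   obs, info = env.reset()
#   print(obs.shape)  # (4, 84, 84) after ProcessFrame84 + ImageToPyTorch + FrameStack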


class MultipleEnvironments:
    def __init__(self, env_id, num_envs, output_path=None):
        self.agent_conns, self.env_conns = zip(*[mp.Pipe() for _ in range(num_envs)])
        env = create_train_env(env_id, output_path=output_path)
        self.num_states = env.observation_space.shape[0]
        self.obs_size = env.observation_space.shape
        self.num_actions = env.action_space.n
        for index in range(num_envs):
            process = mp.Process(target=self.run, args=(index, env_id, output_path))
            process.start()
            # self.env_conns[index].close()

    def run(self, index, env_id, output_path):
        self.agent_conns[index].close()
        env = create_train_env(env_id, output_path=output_path)
        while True:
            request, action = self.env_conns[index].recv()
            if request == "step":
                state, reward, done, trunc, info = env.step(action.item())
                self.env_conns[index].send((np.expand_dims(state, axis=0), reward, done, trunc, info))
            elif request == "reset":
                obs, info = env.reset()
                self.env_conns[index].send((np.expand_dims(obs, axis=0), info))
            else:
                raise NotImplementedError
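
# Usage sketch (hypothetical env id): the parent process drives each worker
# over its agent-side pipe with ("reset", None) / ("step", action) messages,
# matching the protocol handled in run(). Actions are sent as 0-d tensors
# because run() calls action.item():
#
#   import torch
#   envs = MultipleEnvironments("ALE/Breakout-v5", num_envs=4)
#   for conn in envs.agent_conns:
#       conn.send(("reset", None))
#   states = [conn.recv()[0] for conn in envs.agent_conns]
#   for conn in envs.agent_conns:
#       conn.send(("step", torch.tensor(0)))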